{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "aff425fb-9ebd-49ce-ad33-13c92f8c19cd",
   "metadata": {},
   "source": [
    "# Flux Playbook"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c5c4601c-f00d-4d7b-b56b-8eeca09d4f95",
   "metadata": {},
   "source": [
    "### Note:\n",
    "This tutorial is intended to run in a NeMo container (> 24.09). It walks through basic usage of the Flux training and inference pipelines. Note that the full Flux model contains 12 billion parameters and requires a large amount of GPU memory (VRAM) to run the inference script at full size.\n",
    "\n",
    "Important: The Flux checkpoint on Hugging Face requires per-user authentication. Set your own HF token with proper access before running the inference section of this notebook; otherwise the model will be randomly initialized and the output images will be random noise."
   ]
  },
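  {
   "cell_type": "markdown",
   "id": "flux-weights-memory-sketch",
   "metadata": {},
   "source": [
    "To put the 12-billion-parameter figure from the note above in perspective, the sketch below estimates the memory needed to hold the model weights alone. It is a back-of-the-envelope calculation, not a measurement: the helper name `weights_memory_gib` is hypothetical, and the estimate ignores the text encoders, the VAE, activations, and framework overhead.\n",
    "\n",
    "```python\n",
    "def weights_memory_gib(num_params, bytes_per_param):\n",
    "    '''Memory needed to hold the weights alone, in GiB.'''\n",
    "    return num_params * bytes_per_param / 1024**3\n",
    "\n",
    "\n",
    "# 12 billion parameters, as stated in the note above.\n",
    "FLUX_PARAMS = 12e9\n",
    "\n",
    "# Bytes per parameter for two common precisions.\n",
    "for dtype, nbytes in [('fp32', 4), ('bf16/fp16', 2)]:\n",
    "    print(f'{dtype}: ~{weights_memory_gib(FLUX_PARAMS, nbytes):.1f} GiB')\n",
    "```\n",
    "\n",
    "Under these assumptions the weights alone come to roughly 45 GiB in fp32 and 22 GiB in 16-bit precision, which is why the reduced-layer configurations later in this notebook are useful on smaller GPUs."
   ]
  },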
  {
   "cell_type": "markdown",
   "id": "5035f643-f52e-4559-9d47-067936351034",
   "metadata": {},
   "source": [
    "##### Launch a NeMo docker container \n",
    "```\n",
    "docker run --gpus all -it --rm -v <your_nemo_dir>:/opt/NeMo --shm-size=8g \\\n",
    "     -p 8888:8888 --ulimit memlock=-1 --ulimit \\\n",
    "      stack=67108864 nvcr.io/nvidia/nemo:xx.xx\n",
    "```\n",
    "Mounting your own copy of the NeMo repository is optional; it is only needed when you want to test customized changes made outside this notebook."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1cd60981-0177-4b1c-b9d8-4cfac3fbe5cf",
   "metadata": {},
   "source": [
    "### Flux Training with Mock Dataset\n",
    "\n",
    "For illustration, we first look at how to run the pre-defined unit test recipe, in which the number of Flux transformer layers (both joint and single) is set to 1. In this recipe all modules are randomly initialized, so no pre-downloaded checkpoint is needed. We also provide a mock data module that generates image and text embeddings directly, so the text and image encoders are not required.\n",
    "\n",
    "Let's take a look at the configs in this recipe."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "7ee06833-b1d8-4e7c-b2f5-c168ba8014c0",
   "metadata": {},
   "source": [
    "```\n",
    "@run.cli.factory(target=llm.train)\n",
    "def unit_test() -> run.Partial:\n",
    "    '''\n",
    "    Basic functional test, with mock dataset,\n",
    "    text/vae encoders not initialized, ddp strategy,\n",
    "    frozen and trainable layers both set to 1\n",
    "    '''\n",
    "    recipe = flux_training()\n",
    "\n",
    "    # Set params of following modules to Null when image and text provided in the datamodule are embeddings\n",
    "    recipe.model.flux_params.t5_params = None \n",
    "    recipe.model.flux_params.clip_params = None\n",
    "    recipe.model.flux_params.vae_config = None\n",
    "    recipe.model.flux_params.device = 'cuda'\n",
    "\n",
    "    # Set number of layers of Flux\n",
    "    recipe.model.flux_params.flux_config = run.Config(\n",
    "        FluxConfig,\n",
    "        num_joint_layers=1,\n",
    "        num_single_layers=1,\n",
    "    )\n",
    "\n",
    "    recipe.data.global_batch_size = 1\n",
    "    recipe.trainer.strategy.ddp = run.Config(\n",
    "        DistributedDataParallelConfig,\n",
    "        check_for_nan_in_grad=True,\n",
    "        grad_reduce_in_fp32=True,\n",
    "    )\n",
    "    recipe.trainer.max_steps=10\n",
    "    return recipe\n",
    "```"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "fc498346-8ebe-4cfc-a051-55fb85f5f931",
   "metadata": {},
   "source": [
    "In NeMo-2, a pre-defined recipe like this can be launched as follows:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "df4fc115-4ac3-4190-b803-619f7c060a52",
   "metadata": {},
   "outputs": [],
   "source": [
    "!torchrun /opt/NeMo/scripts/flux/flux_training.py --yes --factory unit_test"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "80557854-c85d-499a-abda-68ad5bcbdadb",
   "metadata": {},
   "source": [
    "To keep the playbook simple, we use the smallest number of layers above. You can change the config in the pre-defined recipes to test locally with a different number of layers, number of devices, etc. We also provide other pre-defined recipes in the script for reference."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d2c7a0a6-d104-4f34-9554-478d71b53e13",
   "metadata": {},
   "source": [
    "### Flux Inference\n",
    "Before proceeding, download the [FLUX.1-dev checkpoint][flux] from Hugging Face and save it locally, or set your own Hugging Face token with proper access so it can be downloaded automatically. Otherwise, the notebook will run a randomly initialized dummy model and the output will be pure noise, useful only for illustrating the pipeline.\n",
    "\n",
    "\n",
    "**Troubleshooting:**\n",
    "- If you encounter \"safetensors header too large\" errors, the model files were not fully downloaded.\n",
    "- Verify file sizes: `ae.safetensors` should be a few hundred MB, not just a few KB.\n",
    "- Run `file /temp/FLUX.1-dev/ae.safetensors`; if it reports plain text rather than binary data, the file is a placeholder and not real weights.\n",
    "\n",
    "[flux]: https://huggingface.co/black-forest-labs/FLUX.1-dev\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0d85ee40",
   "metadata": {},
   "outputs": [],
   "source": [
    "####  Download checkpoints using Hugging Face Hub\n",
    "from huggingface_hub import snapshot_download\n",
    "\n",
    "# Download the entire model repository\n",
    "# Replace <HF_token> with your actual Hugging Face token\n",
    "HF_TOKEN = \"<HF_token>\"\n",
    "model_path = snapshot_download(\n",
    "    repo_id=\"black-forest-labs/FLUX.1-dev\",\n",
    "    token=HF_TOKEN,\n",
    "    local_dir=\"/temp/FLUX.1-dev\",\n",
    "    local_dir_use_symlinks=False\n",
    ")\n",
    "\n",
    "print(f\"Model downloaded to: {model_path}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd318155",
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "# Verify the download\n",
    "!ls -la /temp/FLUX.1-dev/\n",
    "!file /temp/FLUX.1-dev/ae.safetensors\n"
   ]
  },
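  {
   "cell_type": "markdown",
   "id": "flux-safetensors-header-check",
   "metadata": {},
   "source": [
    "The \"safetensors header too large\" error mentioned in the troubleshooting list can be checked for directly. A safetensors file starts with an 8-byte unsigned little-endian integer giving the length of a JSON header that follows it, so a text placeholder or truncated file shows up as an implausible length. The sketch below is a minimal check under stated assumptions: the helper `read_safetensors_header` and the 100 MB sanity bound are illustrative choices for this notebook, not part of NeMo or the `safetensors` library, and the path is the one assumed by the download cell above.\n",
    "\n",
    "```python\n",
    "import json\n",
    "import struct\n",
    "from pathlib import Path\n",
    "\n",
    "\n",
    "def read_safetensors_header(path, max_header_bytes=100 * 1024 * 1024):\n",
    "    '''Return the JSON header of a safetensors file, or raise ValueError.'''\n",
    "    path = Path(path)\n",
    "    with path.open('rb') as f:\n",
    "        prefix = f.read(8)\n",
    "        if len(prefix) < 8:\n",
    "            raise ValueError(f'{path} is too short to be a safetensors file')\n",
    "        # First 8 bytes: header length as an unsigned little-endian 64-bit integer.\n",
    "        (header_len,) = struct.unpack('<Q', prefix)\n",
    "        if header_len > max_header_bytes or header_len > path.stat().st_size - 8:\n",
    "            raise ValueError(f'{path}: implausible header length {header_len}')\n",
    "        # The header itself is UTF-8 JSON describing each tensor.\n",
    "        return json.loads(f.read(header_len))\n",
    "\n",
    "\n",
    "# Assumed location, matching the download cell above.\n",
    "ckpt = Path('/temp/FLUX.1-dev/ae.safetensors')\n",
    "if ckpt.exists():\n",
    "    header = read_safetensors_header(ckpt)\n",
    "    num_tensors = sum(1 for key in header if key != '__metadata__')\n",
    "    size_mb = ckpt.stat().st_size / 1e6\n",
    "    print(f'{ckpt.name}: {size_mb:.0f} MB, {num_tensors} tensors')\n",
    "```\n",
    "\n",
    "If the helper raises `ValueError`, re-download the file before running the inference cells."
   ]
  },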
  {
   "cell_type": "markdown",
   "id": "22ff84ad-f080-4d39-b109-5f065f76c0e4",
   "metadata": {},
   "source": [
    "Once you have downloaded the checkpoint, specify its path below and run the following cell.\n",
    "Note that this model contains 12B parameters, so it requires a large amount of GPU memory; without it, inference will run out of memory (OOM)."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ef0e35a1-efd5-4d98-bf6b-03311c2a2977",
   "metadata": {},
   "outputs": [],
   "source": [
    "#### Optional cell: only run this if your machine has enough GPU memory and you downloaded a valid checkpoint in the previous steps\n",
    "CHECKPOINT_PATH = \"/temp/FLUX.1-dev\"\n",
    "\n",
    "# Verify the checkpoint path exists and contains the required files.\n",
    "# {CHECKPOINT_PATH} is expanded by IPython from the Python variable above.\n",
    "!ls -la {CHECKPOINT_PATH}/\n",
    "!ls -la {CHECKPOINT_PATH}/ae.safetensors\n",
    "\n",
    "# Run inference with the downloaded checkpoint\n",
    "!torchrun /opt/NeMo/scripts/flux/flux_infer.py \\\n",
    "  --flux_ckpt {CHECKPOINT_PATH}/transformer \\\n",
    "  --clip_version {CHECKPOINT_PATH}/text_encoder \\\n",
    "  --t5_version {CHECKPOINT_PATH}/text_encoder_2 \\\n",
    "  --vae_ckpt {CHECKPOINT_PATH}/ae.safetensors \\\n",
    "  --do_convert_from_hf \\\n",
    "  --prompts \"A cat holding a sign that says hello world\" \\\n",
    "  --inference_steps 30"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "97c00947-d758-47a3-b169-726cd31785f8",
   "metadata": {},
   "source": [
    "For testing purposes, you can load random weights only and reduce the number of layers to avoid OOM; the output will be pure noise in this case."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "6eaf7b8b-93ea-4209-b333-121f21ab20b6",
   "metadata": {},
   "outputs": [],
   "source": [
    "!torchrun /opt/NeMo/scripts/flux/flux_infer.py --clip_version None --t5_version None --vae_ckpt None --num_joint_layers 4 --num_single_layers 8 --prompts  \"A cat holding a sign that says hello world\" --inference_steps 30"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
